# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

""" Optimizer related helpers. """

from copy import deepcopy
from dataclasses import replace
from itertools import chain
from typing import Dict, List, Iterable

import torch

from .mapping import StateDict, ShardedStateDict, ShardedTensor, \
    LocalNonpersitentObject
from .dict_utils import nested_values


def get_optim_param_to_id_map(optim_params_iter: Iterable[torch.nn.Parameter]) -> Dict[int, int]:
    """Index optimizer parameters by object identity.

    Args:
        optim_params_iter: iterable of optimizer parameters; it is consumed
            exactly once, in order.

    Returns:
        Dict mapping ``id(param)`` to the position at which that object was
        first yielded by ``optim_params_iter``. An object yielded more than
        once keeps the position of its first occurrence.
    """
    first_position_by_identity = {}
    for position, tensor in enumerate(optim_params_iter):
        # setdefault leaves an already-registered identity untouched, so
        # repeated objects keep their earliest position.
        first_position_by_identity.setdefault(id(tensor), position)
    return first_position_by_identity


def get_param_id_to_sharded_param_map(model_sharded_state_dict: ShardedStateDict,
                                      optim_params_iter: Iterable[torch.nn.Parameter]) -> Dict[int, ShardedTensor]:
    """Match optimizer parameters with the model sharded tensors that wrap them.

    A model sharded tensor is matched to an optimizer parameter when its
    ``data`` attribute is the very same object (compared with ``id``) as a
    parameter yielded by ``optim_params_iter``.

    Args:
        model_sharded_state_dict: model sharded state dict; every value
            returned for it by ``nested_values`` is inspected.
        optim_params_iter: iterable of optimizer parameters, indexed by
            ``get_optim_param_to_id_map``.

    Returns:
        Dict mapping the optimizer parameter index to the matching model
        sharded tensor. Sharded tensors whose ``data`` is not an optimizer
        parameter are left out; if several sharded tensors wrap the same
        parameter object, the last one visited wins.
    """
    model_tensors = nested_values(model_sharded_state_dict)
    optim_index_by_identity = get_optim_param_to_id_map(optim_params_iter)

    matched = {}
    for sharded in model_tensors:
        # Indices come from enumerate(), so None unambiguously means "absent"
        # (index 0 is a valid hit and must not be treated as falsy).
        optim_index = optim_index_by_identity.get(id(sharded.data))
        if optim_index is None:
            continue
        matched[optim_index] = sharded
    return matched



def make_sharded_optimizer_tensor(model_param: ShardedTensor, optim_param: torch.Tensor, prefix: str) -> ShardedTensor:
    """Build the sharded tensor describing one optimizer state tensor.

    The result is a copy of ``model_param`` made with ``dataclasses.replace``:
    its ``key`` is prefixed, its ``data`` and ``dtype`` are taken from
    ``optim_param``, and every other field is carried over unchanged.

    Args:
        model_param: sharded tensor of the model parameter the optimizer
            state belongs to.
        optim_param: optimizer state tensor; its shape must equal
            ``model_param.local_shape``.
        prefix: string prepended, followed by a dot, to ``model_param.key``.

    Returns:
        New sharded tensor with key ``f'{prefix}.{model_param.key}'``.

    Raises:
        AssertionError: if the optimizer tensor shape differs from
            ``model_param.local_shape``.
    """
    optim_shape = tuple(optim_param.shape)
    assert optim_shape == model_param.local_shape, \
        f'Optimizer shape ({optim_shape}) does not match model shape ({model_param.local_shape})'
    return replace(model_param, key=f'{prefix}.{model_param.key}', data=optim_param, dtype=optim_param.dtype)


def optim_state_to_sharding_state(optim_state_dict: StateDict, id_to_sharded_param_map: Dict[int, ShardedTensor]):
    """Convert an optimizer state dict to its sharded form, in place.

    Every entry of ``optim_state_dict['state'][param_id]`` is replaced by the
    result of ``make_sharded_optimizer_tensor`` for the model sharded tensor
    registered under ``param_id``, using the key prefix
    ``optimizer.state.<state_key>``. ``optim_state_dict['param_groups']`` is
    replaced by a deep copy in which each group's ``'params'`` value is
    wrapped in ``LocalNonpersitentObject``.

    Args:
        optim_state_dict: optimizer state dict with ``'state'`` and
            ``'param_groups'`` entries; modified in place.
        id_to_sharded_param_map: mapping from optimizer parameter id to the
            corresponding model sharded tensor.

    Returns:
        None. The result is written back into ``optim_state_dict``.

    Raises:
        ValueError: if a parameter id that has at least one state entry is
            missing from ``id_to_sharded_param_map``. The check happens before
            ``optim_state_dict`` is modified.
    """
    converted_state = {}
    for pid, per_param_state in optim_state_dict['state'].items():
        per_param_sharded = {}
        converted_state[pid] = per_param_sharded
        for state_name, state_tensor in per_param_state.items():
            # Checked per entry: a parameter with no state entries is accepted
            # even when it has no sharded counterpart.
            if pid not in id_to_sharded_param_map:
                raise ValueError(
                    f'Param id {pid} does not match any model sharded param')
            per_param_sharded[state_name] = make_sharded_optimizer_tensor(
                id_to_sharded_param_map[pid], state_tensor,
                prefix=f'optimizer.state.{state_name}')

    # Work on a deep copy so the caller's original group dicts are not altered
    # by the wrapping below.
    copied_groups = deepcopy(optim_state_dict['param_groups'])
    optim_state_dict['param_groups'] = copied_groups
    for param_group in copied_groups:
        param_group['params'] = LocalNonpersitentObject(param_group['params'])
    optim_state_dict['state'] = converted_state
